#  -*- coding: utf-8 -*-
#  Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#  http://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.


"""
msmodelslim.utils.function_hijacker 模块的单元测试
"""

import unittest
from unittest.mock import patch, MagicMock

from msmodelslim.utils.function_hijacker import (
    _get_target_id, _create_kwargs_wrapper, FunctionHijacker,
    hijack_function, restore_function, restore_all_hijacked,
    get_original_function, _hijacker
)


def original_func_for_test(a, b=10, c=None):
    """用于测试的原始函数"""
    return a + b + (c or 0)


def replacement_func_for_test(a, b=10, c=None):
    """用于测试的替换函数"""
    return a * b * (c or 1)


def original_func_with_args_kwargs(*args, **kwargs):
    """用于测试带*args和**kwargs的原始函数"""
    return sum(args) + sum(kwargs.values())


def replacement_func_with_args_kwargs(**kwargs):
    """用于测试带*args和**kwargs的替换函数"""
    raw_args = kwargs.get('_raw_args', ())
    if raw_args:
        args_sum = sum([x for x in raw_args if isinstance(x, (int, float))])
    else:
        args_sum = sum([x for x in kwargs.get('args', ()) if isinstance(x, (int, float))])

    kwargs_sum = sum([v for v in kwargs.get('kwargs', {}).values() if isinstance(v, (int, float))])
    other_sum = sum([v for k, v in kwargs.items() if
                     k not in ['args', 'kwargs', '_raw_args'] and isinstance(v, (int, float))])

    return args_sum + kwargs_sum + other_sum


def replacement_func_for_inspection_error(**kwargs):
    """用于测试检查错误的替换函数"""
    raw_args = kwargs.get('_raw_args', ())
    numeric_args = [x for x in raw_args if isinstance(x, (int, float))]
    numeric_kwargs = [v for k, v in kwargs.items() if k != '_raw_args' and isinstance(v, (int, float))]
    return sum(numeric_args) + sum(numeric_kwargs)


def replacement_func_for_bind_error(**kwargs):
    """用于测试绑定错误的替换函数"""
    raw_args = kwargs.get('_raw_args', ())
    if len(raw_args) >= 2:
        return raw_args[0] + raw_args[1]
    else:
        return kwargs.get('a', 0) + kwargs.get('b', 0)


def replacement_func_for_complex_func(**kwargs):
    """用于测试复杂函数的替换函数"""
    raw_args = kwargs.get('_raw_args', ())
    kw_arg = kwargs.get('kw_arg', 1)
    extra = kwargs.get('extra', 0)

    numeric_raw_args = [x for x in raw_args if isinstance(x, (int, float))]

    if raw_args:
        first_arg = raw_args[0]
        return first_arg * kw_arg * (sum(numeric_raw_args) + 1) * (extra + 1)
    else:
        return 0


def complex_func_for_test(a: int, b: str = "default", *args, **kwargs) -> str:
    """用于测试复杂函数的原始函数"""
    return f"{a}-{b}"


def replacement_func_for_complex_test(**kwargs):
    """用于测试复杂函数的替换函数"""
    a = kwargs.get('a', 0)
    b = kwargs.get('b', 'default')
    return f"replaced-{a}-{b}"


class TestTargetId(unittest.TestCase):
    """测试_get_target_id函数"""

    def test_get_target_id_basic(self):
        """测试基本的target_id生成"""

        class TestContainer:
            def test_func(self):
                pass

        container = TestContainer()
        target = (container, 'test_func')
        target_id = _get_target_id(target)

        self.assertIsInstance(target_id, str)
        self.assertIn(f"{id(container)}:", target_id)
        self.assertIn(":test_func", target_id)

    def test_get_target_id_different_objects(self):
        """测试不同对象的target_id唯一性"""

        class TestContainer:
            def func1(self):
                pass

            def func2(self):
                pass

        container1 = TestContainer()
        container2 = TestContainer()

        target1 = (container1, 'func1')
        target2 = (container2, 'func1')
        target3 = (container1, 'func2')

        id1 = _get_target_id(target1)
        id2 = _get_target_id(target2)
        id3 = _get_target_id(target3)

        self.assertNotEqual(id1, id2)
        self.assertNotEqual(id1, id3)


class TestCreateKwargsWrapper(unittest.TestCase):
    """测试_create_kwargs_wrapper函数"""

    def test_wrapper_with_inspectable_function(self):
        """测试可检查函数的包装器"""
        wrapper = _create_kwargs_wrapper(original_func_for_test, replacement_func_for_test)

        result = wrapper(2, 3, 4)
        self.assertEqual(result, 24)

        result = wrapper(a=2, b=3, c=4)
        self.assertEqual(result, 24)

        result = wrapper(a=2, b=3)
        self.assertEqual(result, 6)

    def test_wrapper_with_args(self):
        """测试包含位置参数的包装器"""
        wrapper = _create_kwargs_wrapper(
            original_func_with_args_kwargs, replacement_func_with_args_kwargs
        )

        result = wrapper(1, 2, 3, x=4, y=5)
        self.assertEqual(result, 15)

    def test_wrapper_with_inspection_error(self):
        """测试检查时出错的情况"""
        original_func = MagicMock()
        original_func.__name__ = 'test_func'

        with patch('inspect.signature', side_effect=TypeError("Cannot inspect")):
            wrapper = _create_kwargs_wrapper(
                original_func, replacement_func_for_inspection_error
            )

            result = wrapper(x=1, y=2)
            self.assertEqual(result, 3)

    def test_wrapper_with_inspection_error_with_args(self):
        """测试检查时出错且包含位置参数的情况"""
        original_func = MagicMock()
        original_func.__name__ = 'test_func'

        with patch('inspect.signature', side_effect=TypeError("Cannot inspect")):
            wrapper = _create_kwargs_wrapper(
                original_func, replacement_func_for_inspection_error
            )

            result = wrapper(1, 2, 3, x=4, y=5)
            self.assertEqual(result, 15)

    def test_wrapper_with_bind_error(self):
        """测试绑定参数时出错的情况"""

        def original_func(a, b):
            return a + b

        wrapper = _create_kwargs_wrapper(original_func, replacement_func_for_bind_error)

        result = wrapper(1, 2, 3)
        self.assertEqual(result, 3)

    def test_wrapper_with_inspect_signature_error(self):
        """测试inspect.signature出错的情况"""

        def original_func(a, b):
            return a + b

        def replacement_func(**kwargs):
            return kwargs.get('a', 0) + kwargs.get('b', 0)

        with patch('inspect.signature', side_effect=ValueError("Cannot inspect")):
            wrapper = _create_kwargs_wrapper(original_func, replacement_func)

            result = wrapper(1, 2)
            self.assertIsInstance(result, int)

    def test_wrapper_with_no_args(self):
        """测试无参数函数的包装器"""

        def original_func_no_args():
            return 42

        def replacement_func_no_args():
            return 84

        wrapper = _create_kwargs_wrapper(original_func_no_args, replacement_func_no_args)

        result = wrapper()
        self.assertEqual(result, 84)


class TestFunctionHijacker(unittest.TestCase):
    """测试FunctionHijacker类"""

    def setUp(self):
        """每个测试方法前的设置"""
        self.hijacker = FunctionHijacker()

        class TestContainer:
            @staticmethod
            def original_func(x, y=10):
                return x + y

            @staticmethod
            def another_func(z):
                return z * 2

        self.container = TestContainer()

    def test_initialization(self):
        """测试初始化"""
        hijacker = FunctionHijacker()
        self.assertIsInstance(hijacker.original_functions, dict)
        self.assertIsInstance(hijacker.hijacked_targets, dict)
        self.assertEqual(len(hijacker.original_functions), 0)
        self.assertEqual(len(hijacker.hijacked_targets), 0)

    def test_hijack_function_basic(self):
        """测试基本的函数劫持"""

        def replacement_func(x, y=10):
            return x * y

        target = (self.container, 'original_func')
        self.hijacker.hijack_function(target, replacement_func)

        result = self.container.original_func(2, 3)
        self.assertEqual(result, 6)

        original_func = self.hijacker.get_original_function(target)
        self.assertIsNotNone(original_func)
        self.assertEqual(original_func(2, 3), 5)

    def test_hijack_function_update(self):
        """测试更新已劫持的函数"""

        def replacement_func1(x, y=10):
            return x * y

        def replacement_func2(x, y=10):
            return x ** y

        target = (self.container, 'original_func')

        self.hijacker.hijack_function(target, replacement_func1)
        result1 = self.container.original_func(2, 3)
        self.assertEqual(result1, 6)

        self.hijacker.hijack_function(target, replacement_func2)
        result2 = self.container.original_func(2, 3)
        self.assertEqual(result2, 8)

    def test_restore_function(self):
        """测试恢复函数"""

        def replacement_func(x, y=10):
            return x * y

        target = (self.container, 'original_func')

        self.hijacker.hijack_function(target, replacement_func)
        hijacked_result = self.container.original_func(2, 3)
        self.assertEqual(hijacked_result, 6)

        self.hijacker.restore_function(target)
        restored_result = self.container.original_func(2, 3)
        self.assertEqual(restored_result, 5)

    def test_restore_function_not_found(self):
        """测试恢复不存在的函数"""
        target = (self.container, 'nonexistent_func')

        self.hijacker.restore_function(target)

    def test_restore_all(self):
        """测试恢复所有函数"""

        def replacement_func1(x, y=10):
            return x * y

        def replacement_func2(z):
            return z ** 2

        target1 = (self.container, 'original_func')
        target2 = (self.container, 'another_func')

        self.hijacker.hijack_function(target1, replacement_func1)
        self.hijacker.hijack_function(target2, replacement_func2)

        self.assertEqual(self.container.original_func(2, 3), 6)
        self.assertEqual(self.container.another_func(4), 16)

        self.hijacker.restore_all()

        self.assertEqual(self.container.original_func(2, 3), 5)
        self.assertEqual(self.container.another_func(4), 8)

    def test_get_original_function(self):
        """测试获取原函数"""

        def replacement_func(x, y=10):
            return x * y

        target = (self.container, 'original_func')

        original = self.hijacker.get_original_function(target)
        self.assertIsNone(original)

        self.hijacker.hijack_function(target, replacement_func)
        original = self.hijacker.get_original_function(target)
        self.assertIsNotNone(original)
        self.assertEqual(original(2, 3), 5)

    def test_hijack_nonexistent_attribute(self):
        """测试劫持不存在的属性"""

        def replacement_func():
            return "replaced"

        target = (self.container, 'nonexistent_attr')
        with self.assertRaises(AttributeError):
            self.hijacker.hijack_function(target, replacement_func)

    def test_complex_function_with_kwargs(self):
        """测试复杂的函数劫持（带kwargs）"""

        def original_func(a, b=1, *args, **kwargs):
            numeric_args = [arg for arg in args if isinstance(arg, (int, float))]
            numeric_kwargs = [v for v in kwargs.values() if isinstance(v, (int, float))]
            return a + b + sum(numeric_args) + sum(numeric_kwargs)

        def replacement_func(**kwargs):
            a = kwargs.get('a', 0)
            b = kwargs.get('b', 1)
            args_tuple = kwargs.get('args', ())
            kwargs_dict = kwargs.get('kwargs', {})

            numeric_args = [x for x in args_tuple if isinstance(x, (int, float))]
            numeric_kwargs = [v for v in kwargs_dict.values() if isinstance(v, (int, float))]

            return a * b * (sum(numeric_args) + 1) * (sum(numeric_kwargs) + 1)

        self.container.original_func = original_func

        target = (self.container, 'original_func')
        self.hijacker.hijack_function(target, replacement_func)

        result = self.container.original_func(2, 3, 4, 5, x=6, y=7)
        self.assertEqual(result, 840)

    def test_hijack_builtin_function(self):
        """测试劫持内置函数（通过模拟）"""

        class MockBuiltin:
            def __init__(self):
                self.builtin_func = len

        mock_obj = MockBuiltin()

        def replacement_func(*args, **kwargs):
            return 42

        target = (mock_obj, 'builtin_func')
        self.hijacker.hijack_function(target, replacement_func)

        result = mock_obj.builtin_func([1, 2, 3])
        self.assertEqual(result, 42)

    def test_hijacker_with_no_hijacks(self):
        """测试没有劫持时的恢复操作"""
        self.hijacker.restore_all()

        original = self.hijacker.get_original_function(('dummy', 'dummy'))
        self.assertIsNone(original)


class TestGlobalFunctions(unittest.TestCase):
    """测试全局函数"""

    def setUp(self):
        """每个测试方法前的设置"""
        _hijacker.original_functions.clear()
        _hijacker.hijacked_targets.clear()

        class TestContainer:
            @staticmethod
            def original_func(x, y=10):
                return x + y

        self.container = TestContainer()

    def test_global_hijack_function(self):
        """测试全局hijack_function函数"""

        def replacement_func(x, y=10):
            return x * y

        target = (self.container, 'original_func')
        hijack_function(target, replacement_func)

        result = self.container.original_func(2, 3)
        self.assertEqual(result, 6)

    def test_global_restore_function(self):
        """测试全局restore_function函数"""

        def replacement_func(x, y=10):
            return x * y

        target = (self.container, 'original_func')

        hijack_function(target, replacement_func)
        self.assertEqual(self.container.original_func(2, 3), 6)

        restore_function(target)
        self.assertEqual(self.container.original_func(2, 3), 5)

    def test_global_restore_all_hijacked(self):
        """测试全局restore_all_hijacked函数"""

        def replacement_func1(x, y=10):
            return x * y

        def replacement_func2(x, y=10):
            return x ** y

        target1 = (self.container, 'original_func')
        target2 = (self.container, 'original_func')

        hijack_function(target1, replacement_func1)
        hijack_function(target1, replacement_func2)

        self.assertEqual(self.container.original_func(2, 3), 8)

        restore_all_hijacked()
        self.assertEqual(self.container.original_func(2, 3), 5)

    def test_global_get_original_function(self):
        """测试全局get_original_function函数"""

        def replacement_func(x, y=10):
            return x * y

        target = (self.container, 'original_func')

        original = get_original_function(target)
        self.assertIsNone(original)

        hijack_function(target, replacement_func)
        original = get_original_function(target)
        self.assertIsNotNone(original)
        self.assertEqual(original(2, 3), 5)

    def test_global_functions_with_multiple_targets(self):
        """测试多个目标的全局函数"""

        def replacement_func1(x, y=10):
            return x * y

        def replacement_func2(x, y=10):
            return x + y + 100

        class MultiContainer:
            @staticmethod
            def func1(x, y=10):
                return x + y

            @staticmethod
            def func2(x, y=10):
                return x - y

        container = MultiContainer()

        target1 = (container, 'func1')
        target2 = (container, 'func2')

        hijack_function(target1, replacement_func1)
        hijack_function(target2, replacement_func2)

        self.assertEqual(container.func1(2, 3), 6)
        self.assertEqual(container.func2(10, 3), 113)

        restore_all_hijacked()

        self.assertEqual(container.func1(2, 3), 5)
        self.assertEqual(container.func2(10, 3), 7)

    def test_global_functions_edge_cases(self):
        """测试全局函数的边界情况"""
        restore_all_hijacked()


class TestEdgeCases(unittest.TestCase):
    """测试边界情况"""

    def test_empty_container(self):
        """测试空容器"""

        class EmptyContainer:
            pass

        container = EmptyContainer()
        target = (container, 'nonexistent')

        def replacement():
            return "replaced"

        with self.assertRaises(AttributeError):
            hijack_function(target, replacement)

    def test_none_values_in_wrapper(self):
        """测试包装器中的None值处理"""

        def original_func(x=None):
            return x or 0

        def replacement_func(x=None):
            return (x or 0) + 10

        wrapper = _create_kwargs_wrapper(original_func, replacement_func)

        result = wrapper(x=None)
        self.assertEqual(result, 10)

        result = wrapper()
        self.assertEqual(result, 10)

    def test_function_with_no_parameters(self):
        """测试无参数函数"""

        def original_func():
            return 42

        def replacement_func():
            return 84

        wrapper = _create_kwargs_wrapper(original_func, replacement_func)

        result = wrapper()
        self.assertEqual(result, 84)

    def test_function_with_various_arg_types(self):
        """测试包含各种参数类型的函数"""

        def replacement_func(**kwargs):
            args_val = kwargs.get('args', ())
            kwargs_dict = kwargs.get('kwargs', {})

            if args_val:
                first_arg = args_val[0]
                kw_arg_val = kwargs_dict.get('kw_arg', 1)
                extra_val = kwargs_dict.get('extra', 0)

                return first_arg * kw_arg_val * (sum([x for x in args_val if isinstance(x, (int, float))]) + 1) * (
                            extra_val + 1)
            else:
                return 0

        wrapper = _create_kwargs_wrapper(original_func_with_args_kwargs, replacement_func)

        result = wrapper(2, kw_arg=3, extra=5)
        self.assertEqual(result, 108)

    def test_inspect_signature_with_complex_function(self):
        """测试复杂函数的签名检查"""
        wrapper = _create_kwargs_wrapper(
            complex_func_for_test, replacement_func_for_complex_test
        )

        result = wrapper(42, "test")
        self.assertEqual(result, "replaced-42-test")

        result = wrapper(42)
        self.assertEqual(result, "replaced-42-default")


if __name__ == '__main__':
    unittest.main()